#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import cloudpickle
import datetime
import os
import socket
import sys
import threading
import zmq
import parl
import time
import glob
import multiprocessing as mp

from parl.utils import to_str, to_byte, get_ip_address, logger, isnotebook
from parl.remote.utils import get_subfiles_recursively
from parl.remote import remote_constants
from parl.remote.grpc_heartbeat import HeartbeatServerThread, HeartbeatServerProcess
from parl.remote.utils import get_version


class Client(object):
    """Base class for the remote client.

    For each training task, there is a global client in the cluster which
    submits jobs to the master node. Different `@parl.remote_class` objects
    connect to the same global client in a training task.

    Attributes:
        submit_job_socket (zmq.Context.socket): A REQ socket which submits jobs
                                                to the master node.
        pyfiles (bytes): A serialized dictionary containing the code of python
                         files in local working directory.
        executable_path (str): Directory of the executable python script (or the
                               current working directory when there is none).
        start_time (float): A timestamp (`time.time()`) recording when the
                            client connected to the master.

    """

    def __init__(self, master_address, process_id, distributed_files=None):
        """
        Args:
            master_address (str): Address of the master node, in `{ip}:{port}` format.
            process_id (int): Process id in which client is created. Should use os.getpid() to get the process id.
            distributed_files (list): A list of files to be distributed at all remote instances (e.g. the
                                      configuration file for initialization). Defaults to no extra files.
        """
        if distributed_files is None:
            distributed_files = []

        # State read by the background `_update_job_status` thread must exist
        # before that thread is started.
        self.instance_count = 0
        self.instance_id_to_job = dict()

        self.dead_job_queue = mp.Queue()
        self.client_is_alive = mp.Value('i', True)
        self._create_heartbeat_server()
        job_status_thread = threading.Thread(target=self._update_job_status, args=(self.dead_job_queue, ))
        job_status_thread.daemon = True
        job_status_thread.start()

        self.master_address = master_address
        self.process_id = process_id
        self.ctx = zmq.Context()
        self.lock = threading.Lock()
        self.log_monitor_url = None
        self.threads = []
        self.executable_path = self.get_executable_path()
        # Set before `_create_sockets` starts the master-heartbeat thread, so a
        # `False` written by its exit callback is never overwritten afterwards.
        self.connected_to_master = True
        self._create_sockets(master_address)
        self.check_env_consistency()

        status_thread = threading.Thread(target=self._update_client_status_to_master)
        status_thread.daemon = True
        status_thread.start()
        self.threads.append(status_thread)

        self.pyfiles = self.read_local_files(distributed_files)

    def destroy(self):
        """Destructor function: stop background threads, the zmq context and
        the job-heartbeat process."""
        self.connected_to_master = False
        # Sentinel that makes `_update_job_status` leave its loop.
        self.dead_job_queue.put('exit')
        self.master_heartbeat_thread.exit()
        for th in self.threads:
            th.join()
        self.ctx.destroy()
        self.client_is_alive.value = False
        self.job_heartbeat_process.join()

    def get_executable_path(self):
        """Return the directory containing the `__main__` script.

        Falls back to the current working directory when `__main__` has no
        `__file__` (e.g. an interactive session or a notebook).
        """
        main_file = getattr(sys.modules['__main__'], '__file__', None)
        if main_file:
            # os.path.dirname handles every path separator of the platform,
            # unlike slicing at the last '/'.
            return os.path.dirname(os.path.abspath(main_file))
        return os.getcwd()

    def read_local_files(self, distributed_files=None):
        """Read local python code and store them in a dictionary, which will
        then be sent to the job.

        Args:
            distributed_files (list): A list of files to be distributed at all
                                      remote instances (e.g. the configuration
                                      file for initialization). Glob patterns
                                      in file names are supported, e.g.
                                          distributed_files = ['./*.npy', './test*']
                                      Defaults to no extra files.

        Returns:
            A cloudpickled dictionary with keys `python_files` (the `.py` files
            next to the main script), `other_files` (the user-specified files)
            and `empty_subfolders`.

        Raises:
            ValueError: if a pattern in `distributed_files` matches no local file.
        """
        if distributed_files is None:
            distributed_files = []

        pyfiles = dict()
        pyfiles['python_files'] = {}
        pyfiles['other_files'] = {}

        user_files = []
        user_empty_subfolders = []

        for distributed_file in distributed_files:
            parsed_list = glob.glob(distributed_file)
            if not parsed_list:
                raise ValueError("no local file is matched with '{}', please check your input".format(distributed_file))
            for pathname in parsed_list:
                if os.path.isdir(pathname):
                    pythonfiles, otherfiles, emptysubfolders = get_subfiles_recursively(pathname)
                    user_files.extend(pythonfiles)
                    user_files.extend(otherfiles)
                    user_empty_subfolders.extend(emptysubfolders)
                else:
                    user_files.append(pathname)

        if isnotebook():
            main_folder = './'
        else:
            # Folder of the entry script; './' when it was launched by a bare
            # file name (dirname returns '' in that case).
            main_folder = os.path.dirname(sys.argv[0]) or './'
        code_files = [name for name in os.listdir(main_folder) if name.endswith('.py')]

        for file_name in code_files:
            file_path = os.path.join(main_folder, file_name)
            assert os.path.exists(file_path)
            with open(file_path, 'rb') as code_file:
                pyfiles['python_files'][file_name] = code_file.read()

        for file_name in set(user_files):
            assert os.path.exists(file_name)
            assert not os.path.isabs(file_name), "[XPARL] Please do not distribute a file with absolute path."
            with open(file_name, 'rb') as f:
                pyfiles['other_files'][file_name] = f.read()

        pyfiles['empty_subfolders'] = set(user_empty_subfolders)
        return cloudpickle.dumps(pyfiles)

    def _create_sockets(self, master_address):
        """Create the socket and heartbeat thread used to talk to the master.

        (1) submit_job_socket: a REQ socket that submits jobs to the master node.
        (2) master_heartbeat_thread: replies to heartbeats from the master.

        Raises:
            Exception: if the master does not answer the connect request within
                       the socket's receive timeout.
        """

        # submit_job_socket: submits job to master
        self.submit_job_socket = self.ctx.socket(zmq.REQ)
        self.submit_job_socket.linger = 0
        self.submit_job_socket.setsockopt(zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
        self.submit_job_socket.connect("tcp://{}".format(master_address))
        self.start_time = time.time()

        def master_heartbeat_exit_callback_func():
            logger.warning("[Client] Cannot connect to the master. " "Please check if it is still alive.")
            logger.warning("Client exit replying heartbeat for master.")
            self.connected_to_master = False

        self.master_heartbeat_thread = HeartbeatServerThread(
            heartbeat_exit_callback_func=master_heartbeat_exit_callback_func)
        self.master_heartbeat_thread.daemon = True
        self.master_heartbeat_thread.start()
        self.reply_master_heartbeat_address = self.master_heartbeat_thread.get_address()
        self.threads.append(self.master_heartbeat_thread)

        self.client_id = self.reply_master_heartbeat_address.replace(':', '_') + \
                            '_' + str(int(time.time()))

        # check if the master is connected properly
        try:
            self.submit_job_socket.send_multipart([
                remote_constants.CLIENT_CONNECT_TAG,
                to_byte(self.reply_master_heartbeat_address),
                to_byte(socket.gethostname()),
                to_byte(self.client_id),
            ])
            message = self.submit_job_socket.recv_multipart()
            self.log_monitor_url = to_str(message[1])
        except zmq.error.Again as e:
            logger.warning("[Client] Can not connect to the master, please "
                           "check if master is started and ensure the input "
                           "address {} is correct.".format(master_address))
            self.connected_to_master = False
            raise Exception("Client can not connect to the master, please "
                            "check if master is started and ensure the input "
                            "address {} is correct.".format(master_address)) from e

    def _update_job_status(self, dead_job_queue):
        """Flag remote instances reported dead by the heartbeat process.

        Runs in a background thread until the `'exit'` sentinel (put by
        `destroy`) is received. Each other queue item is an instance id whose
        `job_is_alive` flag is set to False.
        """
        while True:
            instance_id = dead_job_queue.get()
            # the client calls the destroy function
            if isinstance(instance_id, str) and instance_id == 'exit':
                break
            logger.error("[Client] lost connection with a remote instance. ID: {}".format(instance_id))
            job_is_alive = self.instance_id_to_job.get(instance_id)
            if job_is_alive is None:
                # An id this client never registered: nothing to flag. Keep the
                # thread alive (a KeyError here would stop all later updates).
                continue
            job_is_alive.value = False

    def check_env_consistency(self):
        """Verify that the parl and python (major.minor) versions of this client
        match those of the master.

        Raises:
            AssertionError: if the versions differ.
            NotImplementedError: if the master replies with an unexpected tag.
        """
        self.submit_job_socket.send_multipart([remote_constants.CHECK_VERSION_TAG])
        message = self.submit_job_socket.recv_multipart()
        tag = message[0]
        if tag != remote_constants.NORMAL_TAG:
            raise NotImplementedError

        master_parl_version = to_str(message[1])
        master_python_major = to_str(message[2])
        master_python_minor = to_str(message[3])
        client_parl_version = parl.__version__
        client_python_version_major = str(sys.version_info.major)
        client_python_version_minor = str(sys.version_info.minor)
        assert (client_parl_version == master_parl_version
                and client_python_version_major == master_python_major
                and client_python_version_minor == master_python_minor), (
                    "Version mismatch: the 'master' is of version 'parl={}, python={}.{}'. However, "
                    "'parl={}, python={}.{}' is provided in your environment.".format(
                        master_parl_version, master_python_major, master_python_minor,
                        client_parl_version, client_python_version_major, client_python_version_minor))

    def _update_client_status_to_master(self):
        """Periodically send this client's status to the master.

        Runs in a background thread until `connected_to_master` becomes False,
        either because `destroy` was called or because the master did not reply
        within the socket's receive timeout.
        """
        while self.connected_to_master:
            elapsed_time = datetime.timedelta(seconds=int(time.time() - self.start_time))
            client_status = {
                'file_path': self.executable_path,
                'actor_num': self.actor_num.value,
                'time': str(elapsed_time),
                'log_monitor_url': self.log_monitor_url
            }

            # The REQ socket is shared with `submit_job`; the send/recv pair must
            # not interleave with another thread's, and `with` releases the lock
            # on every exit path.
            with self.lock:
                try:
                    self.submit_job_socket.send_multipart([
                        remote_constants.CLIENT_STATUS_UPDATE_TAG,
                        to_byte(self.reply_master_heartbeat_address),
                        cloudpickle.dumps(client_status)
                    ])
                    self.submit_job_socket.recv_multipart()
                except zmq.error.Again:
                    self.connected_to_master = False

            time.sleep(remote_constants.HEARTBEAT_INTERVAL_S)

    def _check_job(self, job_ping_address, max_memory, gpu):
        """Check if this job is still alive before establishing connection with it.

        Sends a heartbeat request (carrying this client's heartbeat address,
        `max_memory`, `gpu` and a freshly generated instance id) and waits up to
        0.9 s for the reply.

        Returns:
            str: the unique instance id, or the int -1 if the job did not reply
                 (it is not ready for connection).
        """
        # job_ping_socket: sends ping signal to job
        job_ping_socket = self.ctx.socket(zmq.REQ)
        job_ping_socket.linger = 0
        job_ping_socket.setsockopt(zmq.RCVTIMEO, int(0.9 * 1000))
        job_ping_socket.connect("tcp://" + job_ping_address)
        instance_id = self._generate_instance_id()
        try:
            job_ping_socket.send_multipart([
                remote_constants.HEARTBEAT_TAG,
                to_byte(self.job_heartbeat_server_addr),
                to_byte(str(max_memory)),
                to_byte(gpu),
                to_byte(instance_id)
            ], )
            job_ping_socket.recv_multipart()
        except zmq.error.Again:
            logger.error(
                "[Client] connects to a finished job, will try again, job_ping_address:{}".format(job_ping_address))
            instance_id = -1
        finally:
            job_ping_socket.close(0)
        return instance_id

    def _create_heartbeat_server(self):
        """Create the grpc-based heartbeat server at the subprocess.

        Sets `self.actor_num`, `self.job_heartbeat_process` and
        `self.job_heartbeat_server_addr` (`{local_ip}:{port}`).
        """
        job_heartbeat_port = mp.Value('i', 0)
        self.actor_num = mp.Value('i', 0)
        self.job_heartbeat_process = HeartbeatServerProcess(job_heartbeat_port, self.actor_num,
                                                            self.client_is_alive, self.dead_job_queue)
        self.job_heartbeat_process.daemon = True
        self.job_heartbeat_process.start()
        assert job_heartbeat_port.value != 0, "fail to initialize heartbeat server for jobs."
        self.job_heartbeat_server_addr = "{}:{}".format(get_ip_address(), job_heartbeat_port.value)

    def _generate_instance_id(self):
        """Return an unique instance id for the remote instance (a zero-padded
        counter string such as '00001')."""
        self.instance_count += 1
        return f"{self.instance_count:05}"

    def submit_job(self, max_memory, n_gpu, job_is_alive):
        """Send a job to the Master node.

        When a `@parl.remote_class` object is created, the global client
        sends a job to the master node. Then the master node will allocate
        a vacant job from its job pool to the remote object.

        Args:
            max_memory (float): Maximum memory (MB) can be used by each remote
                                instance, the unit is in MB and default value is
                                none(unlimited).
            n_gpu (int): Number of GPUs can used in this remote instance.
            job_is_alive: Flag object with a `value` attribute (e.g. an
                          `mp.Value`); it is set to False when the heartbeat
                          process reports the instance as dead.
        Returns:
            An ``InitializedJob`` that has information about available job
            address, or None when the master has no vacant CPU/GPU resources.
        Raises:
            Exception: if the client is not connected or the master rejects the
                       request.
            NotImplementedError: if the master replies with an unknown tag.
        """
        if not self.connected_to_master:
            raise Exception("Client can not submit job to the master, please check if master is connected.")

        # A job that requests GPUs does not occupy a CPU slot.
        n_cpu = 0 if n_gpu > 0 else 1
        while True:
            # The REQ socket is shared with the status-update thread. `with`
            # guarantees the lock is released even if `recv_multipart` times out.
            with self.lock:
                self.submit_job_socket.send_multipart([
                    remote_constants.CLIENT_SUBMIT_TAG,
                    to_byte(self.reply_master_heartbeat_address),
                    to_byte(self.client_id),
                    to_byte(str(n_cpu)),
                    to_byte(str(n_gpu))
                ])
                message = self.submit_job_socket.recv_multipart()
            tag = message[0]
            if tag == remote_constants.NORMAL_TAG:
                job_info = cloudpickle.loads(message[1])
                job_ping_address = job_info.ping_heartbeat_address

                with self.lock:
                    instance_id = self._check_job(job_ping_address, max_memory, job_info.allocated_gpu.gpu)
                if instance_id != -1:
                    self.instance_id_to_job[instance_id] = job_is_alive
                    return job_info
                # The allocated job did not answer the ping: ask the master again.
            # no vacant CPU/GPU resources, cannot submit a new job
            elif tag in (remote_constants.CPU_TAG, remote_constants.GPU_TAG):
                # wait 1 second to avoid requesting in a high frequency.
                time.sleep(1)
                return None
            elif tag == remote_constants.REJECT_GPU_JOB_TAG:
                error_message = "[Client] Request fails. It is not allowed to request GPU resource from a CPU cluster."
                logger.error(error_message)
                raise Exception(error_message)
            elif tag == remote_constants.REJECT_CPU_JOB_TAG:
                error_message = "[Client] Request fails. It is not allowed to request CPU resource from a GPU cluster."
                logger.error(error_message)
                raise Exception(error_message)
            elif tag == remote_constants.REJECT_INVALID_GPU_JOB_TAG:
                error_message = "[Client] request {} GPUs, but rejected.".format(n_gpu)
                logger.error(error_message)
                raise Exception(error_message)
            else:
                raise NotImplementedError


# Process-wide client shared by all remote objects. Created by `connect`,
# recreated by `connect`/`get_global_client` when the current process id differs
# from the one that created it, and reset to None by `disconnect`.
GLOBAL_CLIENT = None


def connect(master_address, distributed_files=None):
    """Create a global client which connects to the master node.

    .. code-block:: python

        parl.connect(master_address='localhost:1234')

    If a client was already created by the current process, it is reused.
    Otherwise (no client yet, or the client belongs to a parent process) a new
    one is created.

    Args:
        master_address (str): The address of the Master node to connect to, in
                              `{ip}:{port}` format.
        distributed_files (list): A list of files to be distributed at all
                                  remote instances (e.g. the configuration
                                  file for initialization). Defaults to no
                                  extra files.

    Raises:
        AssertionError: if `master_address` is not in `{ip}:{port}` format or
                        `distributed_files` is not a list.
        Exception: An exception is raised if the master node is not started.
    """
    if distributed_files is None:
        distributed_files = []
    assert len(master_address.split(":")) == 2, "Please input address in " +\
        "{ip}:{port} format"
    assert isinstance(distributed_files, list), "`distributed_files` should be a list."

    global GLOBAL_CLIENT
    cur_process_id = os.getpid()
    if GLOBAL_CLIENT is None or GLOBAL_CLIENT.process_id != cur_process_id:
        GLOBAL_CLIENT = Client(master_address, cur_process_id, distributed_files)
    elif GLOBAL_CLIENT.master_address != master_address:
        # The existing client is kept; make the ignored address visible.
        logger.warning("[Client] Already connected to the master {}; ignoring the new address {}. "
                       "Call `parl.disconnect()` first to switch to another master.".format(
                           GLOBAL_CLIENT.master_address, master_address))
    logger.info("Remote actors log url: {}".format(GLOBAL_CLIENT.log_monitor_url))


def get_global_client():
    """Get the global client.

    To support process-based programming, we will create a new global client in
    the new process (detected by comparing `os.getpid()` with the process id
    stored on the existing client).

    Returns:
        The global client.

    Raises:
        AssertionError: if `parl.connect` has not been called.
    """
    global GLOBAL_CLIENT
    assert GLOBAL_CLIENT is not None, "Cannot get the client to submit the" +\
        " job, have you connected to the cluster by calling " +\
        "parl.connect(master_address)?"

    cur_process_id = os.getpid()
    if GLOBAL_CLIENT.process_id != cur_process_id:
        # NOTE(review): the recreated client does not receive the
        # `distributed_files` passed to the original `connect` call.
        GLOBAL_CLIENT = Client(GLOBAL_CLIENT.master_address, cur_process_id)
    return GLOBAL_CLIENT


def disconnect():
    """Disconnect the global client from the master node.

    Destroys the global client (if any) and resets it to None. Logs an
    informational message when there is no client to release.
    """
    global GLOBAL_CLIENT
    if GLOBAL_CLIENT is None:
        logger.info("No client to be released. Please make sure that you have called `parl.connect`")
        return
    GLOBAL_CLIENT.destroy()
    GLOBAL_CLIENT = None
    logger.info("The client is disconnected from the master node.")
